import taichi as ti

from ..utils.common import Vec3, Mat3x3
from ..config import GAMMA

# Public API of this module: the names exported by ``from <module> import *``.
# NOTE(review): ``reinhardJodie_toneMapping`` and the helper
# ``lrgb2luminance`` are defined in this module but not listed here;
# confirm whether leaving them out is intentional.
__all__ = [
    "ACES_toneMapping",
    "reinhard_toneMapping",
    "sRGB2lRGB",
    "lRGB2sRGB",
    "lRGB2XYZ",
    "XYZ2lRGB",
]


def lrgb2luminance(lRGB: "Vec3"):
    """
    Return the relative luminance of a linear-light RGB colour.

    Uses the Rec. 709 / sRGB weights (0.2126, 0.7152, 0.0722), which are the
    rounded middle row of ``lRGB2XYZ``. The weights sum to 1, so white
    (1, 1, 1) has luminance 1.

    Args:
        lRGB: linear RGB colour; anything indexable by 0, 1 and 2 works
            (a Taichi vector or a plain Python sequence).

    Returns:
        The weighted channel sum as a scalar. The previous implementation
        returned the element-wise product (a vector), which is not a
        luminance.
    """
    # A dot product written out by hand, so it behaves the same for Taichi
    # vectors/expressions and for ordinary Python sequences.
    return 0.2126 * lRGB[0] + 0.7152 * lRGB[1] + 0.0722 * lRGB[2]


@ti.func
def ACES_toneMapping(lRGB: Vec3) -> Vec3:
    """
    ACES filmic tone mapping (Krzysztof Narkowicz's rational-curve fit).

    Applies ``x * (A*x + B) / (x * (C*x + D) + E)`` to each channel of the
    input, after pre-scaling the input by 0.6 as in the approximate-ACES
    reference implementation.

    References:
        https://64.github.io/tonemapping/  (section "ACES")
        https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/

    A more accurate but more expensive fit of the ACES RRT+ODT by Stephen
    Hill (input matrix -> rational fit -> output matrix) is also described on
    the 64.github.io page; it used to live here as commented-out code.

    Args:
        lRGB: linear-light RGB colour.

    Returns:
        Tone-mapped linear RGB.

    NOTE(review): unlike the reference implementations, the result is not
    clamped. For large inputs it approaches A/C = 2.51/2.43 (about 1.033),
    so callers that need [0, 1] must clamp themselves.
    """
    # Exposure pre-scale used by the reference approximation. Bind it to a
    # new local instead of rebinding the argument.
    x = lRGB * 0.6
    A, B, C, D, E = 2.51, 0.03, 2.43, 0.59, 0.14
    return x * (A * x + B) / (x * (C * x + D) + E)


@ti.func
def reinhard_toneMapping(lRGB: Vec3) -> Vec3:
    """
    Simple Reinhard tone mapping, applied to each channel independently.

    Computes ``x / (x + 1)``, which maps non-negative inputs in [0, inf)
    into [0, 1).

    Args:
        lRGB: linear-light RGB colour.

    Returns:
        Tone-mapped linear RGB.
    """
    denominator = lRGB + 1
    return lRGB / denominator


@ti.func
def reinhardJodie_toneMapping(lRGB: Vec3) -> Vec3:
    """
    Reinhard-Jodie tone mapping.

    Blends two Reinhard curves per channel:
      * the luminance-based curve ``lRGB / (1 + L)``, and
      * the per-channel curve ``tv = lRGB / (1 + lRGB)``.
    The blend weight is ``tv`` itself, so brighter channels lean more towards
    the per-channel curve.

    Reference: https://64.github.io/tonemapping/  (section "Reinhard-Jodie")

    Args:
        lRGB: linear-light RGB colour.

    Returns:
        Tone-mapped linear RGB.
    """
    # Luminance with Rec. 709 weights, from the module's shared helper.
    l = lrgb2luminance(lRGB)
    tv = lRGB / (1 + lRGB)
    # ti.math.mix(x, y, a) == x * (1 - a) + y * a, per channel.
    return ti.math.mix(lRGB / (1 + l), tv, tv)


@ti.func
def sRGB2lRGB(sRGB: Vec3) -> Vec3:
    """
    Decode sRGB-encoded channel values to linear light.

    Channels below 0.04045 use the linear segment ``c / 12.92``. All other
    channels use the power segment ``((c + 0.055) / 1.055) ** GAMMA``.

    NOTE(review): the sRGB standard uses exponent 2.4 for this segment;
    confirm that ``config.GAMMA`` is 2.4 and not an approximate display
    gamma such as 2.2.
    """
    linear_segment = sRGB / 12.92
    power_segment = ((sRGB + 0.055) / 1.055) ** GAMMA
    # ti.select chooses element-wise; both segments are computed up front.
    return ti.select(sRGB < 0.04045, linear_segment, power_segment)


@ti.func
def lRGB2sRGB(lRGB: Vec3) -> Vec3:
    """
    Encode linear-light channel values with the sRGB transfer curve.

    Channels below 0.0031308 use the linear segment ``c * 12.92``. All other
    channels use ``1.055 * c ** (1 / GAMMA) - 0.055``. This is the inverse of
    ``sRGB2lRGB`` for the same ``GAMMA``.

    NOTE(review): the sRGB standard uses exponent 1/2.4 here; confirm that
    ``config.GAMMA`` is 2.4.
    """
    linear_segment = lRGB * 12.92
    power_segment = 1.055 * lRGB ** (1 / GAMMA) - 0.055
    # Element-wise choice between the two segments.
    return ti.select(lRGB < 0.0031308, linear_segment, power_segment)


# Linear sRGB -> CIE XYZ matrix, given row by row and intended to be applied
# as ``lRGB2XYZ @ rgb``. The coefficients are the standard sRGB / Rec. 709
# primaries with a D65 white point: each row sums to roughly the D65 white's
# X, Y, Z (0.9505, 1.0, 1.0891), and the middle row gives luminance Y.
lRGB2XYZ = Mat3x3(
    [0.41239080, 0.35758434, 0.18048079],
    [0.21263901, 0.71516868, 0.07219232],
    [0.01933082, 0.11919478, 0.95053215],
)
# Inverse transform (CIE XYZ -> linear sRGB), computed once at import time.
XYZ2lRGB = lRGB2XYZ.inverse()
